from cleanlab_tlm import TLM

from lfx.custom import Component
from lfx.io import (
    DropdownInput,
    MessageTextInput,
    Output,
    SecretStrInput,
)
from lfx.schema.message import Message


class CleanlabEvaluator(Component):
    """A component that evaluates the trustworthiness of LLM responses using Cleanlab.

    This component takes a prompt and response pair, along with optional system instructions,
    and uses Cleanlab's evaluation algorithms to generate a trust score and explanation.

    Inputs:
        - system_prompt (MessageTextInput): Optional system-level instructions prepended to the user prompt.
        - prompt (MessageTextInput): The user's prompt or query sent to the LLM.
        - response (MessageTextInput): The response generated by the LLM to be evaluated. This should come from the
          LLM component, i.e. OpenAI, Gemini, etc.
        - api_key (SecretStrInput): Your Cleanlab API key.
        - model (DropdownInput): The model used by Cleanlab to evaluate the response (can differ from the
          generation model).
        - quality_preset (DropdownInput): Tradeoff setting for accuracy vs. speed and cost. Higher presets are
          slower but more accurate.

    Outputs:
        - response_passthrough (Message): The original response, passed through for downstream use.
        - score (number): A float between 0 and 1 indicating Cleanlab's trustworthiness score for the response.
        - explanation (Message): A textual explanation of why the response received its score.

    This component works well in conjunction with the CleanlabRemediator to create a complete trust evaluation
    and remediation pipeline.

    More details on the evaluation metrics can be found here: https://help.cleanlab.ai/tlm/tutorials/tlm/
    """

    display_name = "Cleanlab Evaluator"
    description = "Evaluates any LLM response using Cleanlab and outputs trust score and explanation."
    icon = "Cleanlab"
    name = "CleanlabEvaluator"

    inputs = [
        MessageTextInput(
            name="system_prompt",
            display_name="System Message",
            info="System-level instructions prepended to the user query.",
            value="",
        ),
        MessageTextInput(
            name="prompt",
            display_name="Prompt",
            info="The user's query to the model.",
            required=True,
        ),
        MessageTextInput(
            name="response",
            display_name="Response",
            info="The response to the user's query.",
            required=True,
        ),
        SecretStrInput(
            name="api_key",
            display_name="Cleanlab API Key",
            info="Your Cleanlab API key.",
            required=True,
        ),
        DropdownInput(
            name="model",
            display_name="Cleanlab Evaluation Model",
            options=[
                "gpt-4.1",
                "gpt-4.1-mini",
                "gpt-4.1-nano",
                "o4-mini",
                "o3",
                "gpt-4.5-preview",
                "gpt-4o-mini",
                "gpt-4o",
                "o3-mini",
                "o1",
                "o1-mini",
                "gpt-4",
                "gpt-3.5-turbo-16k",
                "claude-3.7-sonnet",
                "claude-3.5-sonnet-v2",
                "claude-3.5-sonnet",
                "claude-3.5-haiku",
                "claude-3-haiku",
                "nova-micro",
                "nova-lite",
                "nova-pro",
            ],
            info="The model Cleanlab uses to evaluate the response. This does NOT need to be the same model that "
            "generated the response.",
            value="gpt-4o-mini",
            required=True,
            advanced=True,
        ),
        DropdownInput(
            name="quality_preset",
            display_name="Quality Preset",
            options=["base", "low", "medium", "high", "best"],
            value="medium",
            info="This determines the accuracy, latency, and cost of the evaluation. Higher quality is generally "
            "slower but more accurate.",
            required=True,
            advanced=True,
        ),
    ]

    outputs = [
        Output(
            display_name="Response",
            name="response_passthrough",
            method="pass_response",
            types=["Message"],
        ),
        Output(display_name="Trust Score", name="score", method="get_score", types=["number"]),
        Output(
            display_name="Explanation",
            name="explanation",
            method="get_explanation",
            types=["Message"],
        ),
    ]

    def _evaluate_once(self) -> dict:
        """Run the Cleanlab evaluation at most once and cache the result on the instance.

        All three outputs derive from the same evaluation, so the TLM API call is
        made only on the first output request; subsequent calls reuse the cache.

        Returns:
            The raw result dict from ``TLM.get_trustworthiness_score``.
        """
        if not hasattr(self, "_cached_result"):
            # Prepend the optional system message so Cleanlab evaluates the
            # same context the generation model actually saw.
            full_prompt = f"{self.system_prompt}\n\n{self.prompt}" if self.system_prompt else self.prompt
            tlm = TLM(
                api_key=self.api_key,
                options={"log": ["explanation"], "model": self.model},
                quality_preset=self.quality_preset,
            )
            self._cached_result = tlm.get_trustworthiness_score(full_prompt, self.response)
        return self._cached_result

    def get_score(self) -> float:
        """Return Cleanlab's trustworthiness score (0.0-1.0) for the response."""
        result = self._evaluate_once()
        raw_score = result.get("trustworthiness_score")
        # Treat an explicit None in the payload like a missing key: the .get()
        # default alone would let None through, breaking the float contract and
        # the ":.2f" formatting below.
        score = 0.0 if raw_score is None else float(raw_score)
        self.status = f"Trust score: {score:.2f}"
        return score

    def get_explanation(self) -> Message:
        """Return Cleanlab's textual explanation of the trust score."""
        result = self._evaluate_once()
        # "or {}" guards against a present-but-None "log" entry, which would
        # raise AttributeError on the chained .get().
        explanation = (result.get("log") or {}).get("explanation", "No explanation returned.")
        # Keep status reporting consistent with the other output methods.
        self.status = "Explanation retrieved."
        return Message(text=explanation)

    def pass_response(self) -> Message:
        """Pass the evaluated response through unchanged for downstream use."""
        self.status = "Passing through response."
        return Message(text=self.response)
